from .base import DownloadData
from torchvision import datasets



# Short human-readable CIFAR-10 label names, listed in class-index order 0-9
# (the same order as torchvision's CIFAR10 labels, with abbreviated names).
CIFAR10_CLASSES = (
    'plane',
    'car',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
)


class Cifar10(DownloadData):
    """CIFAR-10 data source built on the project's ``DownloadData`` base.

    This subclass only supplies the dataset-specific loading step
    (``_download``); the remaining behavior (batching, splitting, loaders)
    is presumably provided by ``DownloadData`` -- confirm in ``.base``.
    """

    def __init__(self, batch_size=128):
        """Create the CIFAR-10 data source.

        Args:
            batch_size: Forwarded unchanged to ``DownloadData.__init__``.
                Defaults to 128.
        """
        super().__init__(batch_size)

    def _download(self):
        """Load the CIFAR-10 train and test splits, downloading if needed.

        Both splits are read from ``self.data_dir``. That attribute is not set
        in this class; it is assumed to be provided by ``DownloadData`` (TODO
        confirm). Because ``download=True``, torchvision fetches the archive
        into that directory when it is not already present.

        Sets:
            self.full: the CIFAR-10 training split (``train=True``). The name
                suggests the base class later splits it further (e.g. into
                train/validation) -- verify against ``DownloadData``.
            self.test: the CIFAR-10 test split (``train=False``).
        """
        # NOTE: no ``transform`` is passed, so samples come back as raw
        # (PIL image, label) pairs; any augmentation or tensor conversion must
        # be attached elsewhere (presumably by the base class -- confirm).
        self.full = datasets.CIFAR10(self.data_dir, train=True, download=True)
        self.test = datasets.CIFAR10(self.data_dir, train=False, download=True)


